import typing
from typing import Any, Callable, List, Tuple, Union

import numpy as np
import os, sys
import paddle
from .abc_interpreter import Interpreter
from ..data_processor.readers import preprocess_image, read_image, restore_image, preprocess_inputs
from ..data_processor.visualizer import visualize_overlay


class SmoothGradInterpreter(Interpreter):
    """
    Smooth Gradients Interpreter.

    The Smooth Gradients method reduces meaningless local variations in partial
    derivatives by adding random noise to the inputs multiple times and taking
    the average of the resulting gradients.

    More details regarding the Smooth Gradients method can be found in the original paper:
    http://arxiv.org/pdf/1706.03825.pdf
    """

    def __init__(self,
                 paddle_model,
                 use_cuda=True,
                 model_input_shape=[3, 224, 224]):
        """
        Initialize the SmoothGradInterpreter.

        Args:
            paddle_model: The model to interpret. It must be callable on a batch
                    tensor and expose ``train()`` and ``named_sublayers()``
                    (presumably a ``paddle.nn.Layer`` -- confirm against callers).
                    Softmax is applied to its output along axis 1, so it is
                    expected to return per-class scores.
            use_cuda (bool, optional): Whether or not to use cuda. Default: True
            model_input_shape (list, optional): The input shape of the model. Default: [3, 224, 224]
        """
        Interpreter.__init__(self)
        self.paddle_model = paddle_model
        self.use_cuda = use_cuda
        # Copy so the shared default list (or the caller's list) is never
        # aliased by this instance.
        self.model_input_shape = list(model_input_shape)
        self.data_type = 'float32'
        self.paddle_prepared = False

    def interpret(self,
                  inputs,
                  labels=None,
                  noise_amount=0.1,
                  n_samples=50,
                  visual=True,
                  save_path=None):
        """
        Main function of the interpreter.

        Args:
            inputs (str or list of strs or numpy.ndarray): The input image filepath or a list of filepaths or numpy array of read images.
            labels (list or tuple or numpy.ndarray, optional): The target labels to analyze. The number of labels should be equal to the number of images. If None, the most likely label for each image will be used. Default: None
            noise_amount (float, optional): Noise level of added noise to the image.
                                            The std of Gaussian random noise is noise_amount * (x_max - x_min),
                                            computed per image. Default: 0.1
            n_samples (int, optional): The number of new images generated by adding noise. Must be at least 1. Default: 50
            visual (bool, optional): Whether or not to visualize the processed image. Default: True
            save_path (str or list of strs or None, optional): The filepath(s) to save the processed image(s). If None, the image will not be saved. Default: None

        :return: interpretations/gradients for each image, averaged over the
                 ``n_samples`` noisy copies; same shape as the preprocessed data
        :rtype: numpy.ndarray
        :raises ValueError: if ``n_samples`` is smaller than 1, or if the number
                 of labels does not match the number of images

        Example::

            import interpretdl as it
            # ``paddle_model`` is an already-built model object (see ``__init__``).
            sg = it.SmoothGradInterpreter(paddle_model, use_cuda=True)
            gradients = sg.interpret(img_path, visual=True, save_path='assets/sg_test.jpg')
        """
        if n_samples < 1:
            # Averaging over zero samples would divide by zero below.
            raise ValueError(
                "n_samples must be at least 1, got {}".format(n_samples))

        imgs, data, save_path = preprocess_inputs(inputs, save_path,
                                                  self.model_input_shape)

        self.data_type = np.asarray(data).dtype

        if not self.paddle_prepared:
            self._paddle_prepare()

        if labels is None:
            # No target given: use the model's own top prediction per image.
            _, labels = self.predict_fn(data, None)

        labels = np.array(labels).reshape((len(imgs), 1))

        # Per-image noise std: noise_amount * (max - min) over all non-batch axes.
        sample_axes = tuple(range(1, data.ndim))
        stds = noise_amount * (
            np.max(data, axis=sample_axes) - np.min(data, axis=sample_axes))
        # Shape (N, 1, ..., 1) so each image's std broadcasts over its own pixels.
        noise_scale = np.reshape(stds, (-1, ) + (1, ) * (data.ndim - 1))

        total_gradients = np.zeros_like(data)
        for _ in range(n_samples):
            noise = np.float32(
                np.random.normal(0.0, noise_scale, data.shape))
            gradients, _ = self.predict_fn(data + noise, labels)
            total_gradients += gradients

        avg_gradients = total_gradients / n_samples

        for i, img in enumerate(imgs):
            visualize_overlay(avg_gradients[i], img, visual, save_path[i])

        return avg_gradients

    def _paddle_prepare(self, predict_fn=None):
        """
        Select the device, configure the model and build ``self.predict_fn``.

        Args:
            predict_fn (callable, optional): A ready-made function
                ``predict_fn(data, labels) -> (gradients, labels)``. If given, it
                is installed as-is and no device/model setup is performed.
                Default: None
        """
        if predict_fn is None:
            if self.use_cuda:
                paddle.set_device('gpu:0')
            else:
                paddle.set_device('cpu')

            # The model is put in train mode, then batch-norm layers are told to
            # use global statistics and dropout probability is zeroed, so the
            # forward pass stays deterministic.
            self.paddle_model.train()

            for n, v in self.paddle_model.named_sublayers():
                if "batchnorm" in v.__class__.__name__.lower():
                    v._use_global_stats = True
                if "dropout" in v.__class__.__name__.lower():
                    v.p = 0

            def predict_fn(data, labels):
                """Return (d prob[label] / d data, labels) for a batch."""
                data = paddle.to_tensor(data)
                data.stop_gradient = False
                out = self.paddle_model(data)
                out = paddle.nn.functional.softmax(out, axis=1)
                preds = paddle.argmax(out, axis=1)
                if labels is None:
                    labels = preds.numpy()
                # Flatten to (N,): one_hot appends the class axis, so (N, 1)
                # labels would give (N, 1, C) and broadcast against the (N, C)
                # output into (N, N, C), mixing every label of the batch into
                # every image's gradient.
                flat_labels = np.asarray(labels).reshape(-1)
                labels_onehot = paddle.nn.functional.one_hot(
                    paddle.to_tensor(flat_labels), num_classes=out.shape[1])
                # Probability of the target class for each sample, shape (N,).
                target = paddle.sum(out * labels_onehot, axis=1)
                gradients = paddle.grad(outputs=[target], inputs=[data])[0]
                return gradients.numpy(), labels

        self.predict_fn = predict_fn
        self.paddle_prepared = True
